%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sbn
import statsmodels.api as sm
from sklearn.preprocessing import OneHotEncoder
%aimport HER2_classifier
class myspace(object):
    """Plain attribute container used as the ``args`` object in this notebook.

    An instance starts empty, or pre-populated from keyword arguments, and
    accepts arbitrary attribute assignment afterwards.

    Parameters
    ----------
    **kwargs
        Optional initial attributes; ``myspace()`` with no arguments yields
        an empty container.
    """

    def __init__(self, **kwargs):
        # Store keyword arguments directly as instance attributes.
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Show every stored attribute so a displayed config is self-describing.
        fields = ', '.join(
            '{}={!r}'.format(key, value) for key, value in vars(self).items()
        )
        return '{}({})'.format(type(self).__name__, fields)
# Run configuration passed to the HER2_classifier calls in this notebook.
# Every option is a single-element list (presumably mirroring how a
# command-line parser would deliver them -- TODO confirm against
# HER2_classifier).
args = myspace()
vars(args).update({
    'data': ['./HER2_SKBR3_data_6-7-21/'],
    'out': ['./output/'],
    'drug': ['Neratinib'],          # other value noted here: ['Trastuzumab']
    'sensitive_line': ['WT'],
    'resistant_line': ['T798I'],
    'load': ['normalized'],         # other value noted here: ['raw']
    'nclus': [15],
    'resample_sz': [125],
    'burnin': [0],
})
# Load the dataset through HER2_classifier.load_data, which returns the data
# table plus two selectors (clover_sel / mscarl_sel -- presumably the column
# selections for the two measured channels; TODO confirm in HER2_classifier).
data, clover_sel, mscarl_sel = HER2_classifier.load_data(args)
# Sanity check: length of the clover selector and the first five entries of
# each selector.
print('len selector:', len(clover_sel))
print(clover_sel[0:5])
print(mscarl_sel[0:5])
# Preview the table, then list the distinct values in the drug, cell_line and
# mutant columns. These are bare expressions: they only display when they end
# a notebook cell and print nothing when run as a plain script.
data.head()
data.drug.unique()
data.cell_line.unique()
data.mutant.unique()
# Pass the table and both selectors through HER2_classifier.filter_na and
# rebind all three to the returned values (presumably removes missing data --
# verify in HER2_classifier).
data, clover_sel, mscarl_sel = HER2_classifier.filter_na(data, args, clover_sel, mscarl_sel)
# Re-run the same sanity prints so the selectors can be compared with their
# values before filtering.
print('len selector:', len(clover_sel))
print(clover_sel[0:5])
print(mscarl_sel[0:5])
# Bare expression: previews the filtered table when it ends a notebook cell.
data.head()
# Number of non-null track_index entries for every (mutant, drug) pair, sorted
# ascending. Selecting the column before counting avoids counting every other
# column, and the sorted result is computed once and reused for both views.
track_counts = data.groupby(['mutant', 'drug'])['track_index'].count().sort_values()
# Smallest and largest groups (bare expressions: displayed only in a notebook).
track_counts.head(15)
track_counts.tail(15)
# Histogram of the per-(mutant, drug) counts after dropping rows whose mutant
# is WT, T798I or ND611. Rows are filtered before grouping, as in the
# original, so the set of groups seen by the histogram is unchanged.
plt.hist(
    data[~data.mutant.isin(['WT', 'T798I', 'ND611'])]
    .groupby(['mutant', 'drug'])['track_index']
    .count()
)
# Pass both selectors through HER2_classifier.burnin and rebind them to the
# returned values (args.burnin is [0] in this notebook's configuration).
clover_sel, mscarl_sel = HER2_classifier.burnin(args, clover_sel, mscarl_sel)
# Report the length of each selector (clover first, then mscarl) ...
for _selector in (clover_sel, mscarl_sel):
    print('len selector:', len(_selector))
# ... followed by the first five entries of each, in the same order.
for _selector in (clover_sel, mscarl_sel):
    print(_selector[0:5])
# Build the clustering input from the table and both selectors
# (args.resample_sz is [125] in this notebook's configuration).
X_train = HER2_classifier.resample(data, args, clover_sel, mscarl_sel)
# Fit the time-series k-means step; returns per-sample predictions and the
# fitted model object. plot=True asks the helper to plot, save=None leaves the
# save target unset (presumably "do not write a file" -- TODO confirm).
y_pred, km = HER2_classifier.fit_timeseries_kmeans(args, X_train, plot=True, save=None)
# Summarise the cluster assignments; returns cm and lb (presumably a
# cluster-proportion matrix and its row labels -- verify in HER2_classifier).
cm, lb = HER2_classifier.quantify_cluster_prop(args, data, y_pred)
# Bare expression: shows the dimensions of cm when it ends a notebook cell.
cm.shape
HER2_classifier.plot_cluster_corr(cm, save=None)
# Reduce cm to a low-dimensional representation; returns the result table and
# the fitted reducer (its components_ attribute is read later in the notebook).
res, pca = HER2_classifier.reduce_dim(args, cm, lb, plot=True, save=None)
# Tabulate the first two rows of pca.components_ as PC1 / PC2, one table row
# per column of components_ (indexed by clus_feat).
pc_loadings = pd.DataFrame({
    'clus_feat': range(pca.components_.shape[1]),
    'PC1': pca.components_[0],
    'PC2': pca.components_[1],
})
pc_loadings.head()
# One bar chart per component, with bars ordered by ascending loading.
for _pc in ('PC1', 'PC2'):
    plt.figure(figsize=(10, 7))
    sbn.barplot(
        x='clus_feat',
        y=_pc,
        data=pc_loadings,
        order=pc_loadings.sort_values(by=_pc).clus_feat,
    )
    plt.show()
# Bare expression: displays the reduced-dimension table when it ends a
# notebook cell.
res
# Run the batch-effect check on the reduced table; the returned table is
# reused by predict_mutants further down.
batch_res = HER2_classifier.check_batch_effects(args, res, plot=True, save=None)
batch_res.head()
# Train the classifier on the reduced table; returns the fitted model and an
# accuracy value (accuracy is not used again in the visible code).
model, accuracy = HER2_classifier.train_classifier(args, res, plot=True, save=None)
# Score the mutants with the trained model.
prob_res = HER2_classifier.predict_mutants(args, model, res, batch_res)
prob_res
prob_res.tail(10)
# Scatter the predictions on columns pc1 / pc2, coloured by the 'prob_res'
# column and with marker style taken from the 'call' column.
# NOTE(review): the hue column has the same name as the DataFrame variable --
# confirm predict_mutants really emits a column called 'prob_res'.
plt.figure(figsize=(10,10))
sbn.scatterplot(x='pc1', y='pc2', data=prob_res, hue='prob_res', style='call', s=300)
plt.show()
# Scratch arithmetic (bare expressions; displayed only in a notebook):
# the ratio 0.6 : 0.4, approximately 1.5 ...
0.6/0.4
# ... and the base-2 log of the inverse ratio 0.4 : 0.6, approximately -0.585.
np.log2(0.4/0.6)